Conversation

@ErwinTerpstra (Contributor) commented Jan 19, 2026

Proposed changes

Support for packed 4-bit floating point (fp4) for both the A and B tensors in block scale gemm. Tested with A using a 1D block scale and B using a 2D block scale. Works for both the "regular" and Preshuffle-B pipelines. Note that the regular pipeline stores data as fp8 in LDS (as this is how int4 was implemented). The WP pipeline stores tensor A as fp4 in LDS and dequantizes it when loading to registers.

Changes include:

  • Add fp4 support to the ABQuant example, with and without PreshuffleB
  • Tests for fp4 on both the A and B tensors (a4w4), covering the base case, irregular sizes, and the preshuffle-B pipeline
  • Other changes:
    • Add support to InterleavedPKTypeLoader for generic type conversions instead of only int4
    • Add a LUT for converting fp4 to fp8, which improves performance for a 4K tensor by around 25% on gfx12. Disabled by default; gated behind TEST_convert_with_table. A sketch of the idea follows this list.
    • Some helper traits for working with packed or mixed-precision types, including a method to determine the MFMA type based on the input types
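
A minimal sketch of the table-based fp4 → fp8 idea, outside of CK Tile: fp4 (e2m1) has only 16 possible codes, so a nibble can be converted with one indexed load instead of per-element bit manipulation. The names here (fp4_e2m1_lut, unpack_fp4x2) are hypothetical, and the table maps to float for readability; the PR's table would map to fp8 bit patterns for storage in LDS.

#include <array>
#include <cstdint>
#include <utility>

// Hypothetical 16-entry table covering every fp4 (e2m1) code; the real table
// would hold fp8 bit patterns instead of floats.
constexpr std::array<float, 16> fp4_e2m1_lut = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
   -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Unpack one byte holding two fp4 values (assuming the low nibble holds the
// first element) with two table lookups instead of decoding exponent/mantissa
// bits per element.
inline std::pair<float, float> unpack_fp4x2(std::uint8_t packed)
{
    return {fp4_e2m1_lut[packed & 0x0F], fp4_e2m1_lut[packed >> 4]};
}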

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which helps the maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

@ErwinTerpstra changed the title from "Support for a4w4 (fp4) in block scale gemm AB quant" to "[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant" on Jan 21, 2026
@krithalith requested a review from ex-rzr on January 21, 2026 at 12:40
using LargestInputType = largest_type_t<ADataType, BDataType>;
if constexpr(is_packed_type_v<LargestInputType>)
{
    return t<fp8_t>{};
Contributor:

I wouldn't expect this to change anytime soon, but for maintainability reasons I'd consider adding a:

static_assert(sizeof(typename LargestInputType::type) == sizeof(fp8_t));

Contributor Author:

Done
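
For illustration, a self-contained sketch of what the selection plus the suggested assert might look like; the stand-in types (fp8_t, fp4_raw, pk_fp4_t, the t<> tag) and the free function select_mfma_type are hypothetical and only approximate the CK Tile traits referenced in the diff.

#include <cstdint>

struct fp8_t   { std::uint8_t bits; };                        // stand-in 8-bit float
struct fp4_raw { std::uint8_t bits; };                        // stand-in unpacked fp4 element
struct pk_fp4_t { using type = fp4_raw; std::uint8_t bits; }; // two fp4 values per byte
template <typename T> struct t {};                            // tag wrapper, as in the diff

template <typename T> inline constexpr bool is_packed_type_v = false;
template <> inline constexpr bool is_packed_type_v<pk_fp4_t> = true;

template <typename LargestInputType>
constexpr auto select_mfma_type()
{
    if constexpr(is_packed_type_v<LargestInputType>)
    {
        // Guard suggested in the review: the packed element's storage must
        // match fp8 so returning an fp8 MFMA type stays valid.
        static_assert(sizeof(typename LargestInputType::type) == sizeof(fp8_t));
        return t<fp8_t>{};
    }
    else
    {
        return t<LargestInputType>{};
    }
}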

{
    const BDataType pk_val = b_element_op(b_k_n(index));
    const fp32x2_t fp32_val = pk_val.to_fp32x2();
    self(index) = (index[0] & 1) ? fp32_val.hi : fp32_val.lo;
Contributor:

For a_acc you do (index[1] & 1) and for b_acc you do (index[0] & 1). The reason is not immediately apparent, and the removed hunk always did (k & 1). As you've explained to me, this is because A is MxK and B is KxN.

You may want to add a comment explaining it or, even better, make the code self-explanatory by doing something like

constexpr auto A_TENSOR_K_DIM = 1;
constexpr auto B_TENSOR_K_DIM = 0;
(index[A_TENSOR_K_DIM] & 1)
(index[B_TENSOR_K_DIM] & 1)

Contributor Author:

Done
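
A minimal standalone sketch of the named-K-dimension approach for the reference dequant; the helper select_half and the std::array index are hypothetical stand-ins for the tensor index type used in the test code.

#include <array>
#include <cstddef>

// A is laid out MxK and B is KxN, so the packed (two-elements-per-byte)
// dimension is index 1 for A and index 0 for B.
constexpr std::size_t A_TENSOR_K_DIM = 1; // A: [M, K]
constexpr std::size_t B_TENSOR_K_DIM = 0; // B: [K, N]

struct fp32x2 { float lo; float hi; };

// Even K index -> low half of the dequantized pair, odd K index -> high half.
inline float select_half(const fp32x2& pair,
                         const std::array<std::size_t, 2>& index,
                         std::size_t k_dim)
{
    return (index[k_dim] & 1) ? pair.hi : pair.lo;
}

For B the call site would then read, for example, select_half(pk_val.to_fp32x2(), index, B_TENSOR_K_DIM).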

CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
if constexpr(numeric_traits<SrcDataType>::PackedSize > 1)
Contributor:

Use is_packed_type_v here

Contributor Author:

Done
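
A sketch of how such a trait can wrap the PackedSize query so call sites read as a single named condition. The definitions of numeric_traits, pk_int4_t and pk_fp4_t below are simplified stand-ins; the actual CK Tile definitions may differ.

// Scalar types carry one logical element per object.
template <typename T>
struct numeric_traits
{
    static constexpr int PackedSize = 1;
};

struct pk_int4_t { unsigned char bits; }; // stand-in: two int4 values per byte
struct pk_fp4_t  { unsigned char bits; }; // stand-in: two fp4 values per byte

template <> struct numeric_traits<pk_int4_t> { static constexpr int PackedSize = 2; };
template <> struct numeric_traits<pk_fp4_t>  { static constexpr int PackedSize = 2; };

// A type is "packed" when it stores more than one logical element.
template <typename T>
inline constexpr bool is_packed_type_v = (numeric_traits<T>::PackedSize > 1);

static_assert(is_packed_type_v<pk_fp4_t>);
static_assert(!is_packed_type_v<float>);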

using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
                                    BTypeToUse,
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
                                    typename Problem::ComputeDataType,
Contributor:

NIT: I'm wondering whether it would make more sense to rename this to PrecomputedComputeDataType, because "compute" is a verb and thus makes me think of a function.

Contributor:

Or ComputationDataType

Contributor Author:

I understand where you are coming from, but ComputeDataType is the existing convention for the MFMA input type in CK/CK Tile.

abquant_quantgrouped_fp4_instance_factory(lut);
abquant_quantgrouped_fp8_instance_factory(lut);
abquant_quantgrouped_bf8_instance_factory(lut);
abquant_quantgrouped_preshuffleb_fp4_instance_factory(lut);
Contributor:

Can't this and the non-preshuffleb variant be in the same file/function like we do on fp8 and bf8?

Contributor Author:

I split them specifically because the preshuffleb pipeline is really slow to compile. This way it can start compiling simultaneously with the other instances, and we don't extend compile times with a single translation unit taking longer than necessary. For other instances (e.g. the bquant instances) the preshuffleb variants are also split out.

So, for consistency, we could actually also split the fp8/bf8 instances into a preshuffle-specific file.

std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;

// Calculate and display reference timing
using DurationType = std::chrono::duration<double>;
Contributor:

You could directly use std::chrono::milliseconds

Contributor Author:

That would only give millisecond precision, right? (Not that this does crucial timing, but I do look at 0.1 ms precision.)
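
A minimal standalone illustration of the precision point: std::chrono::milliseconds is an integer tick count and truncates, while a double-based duration (as in the existing duration<double>) keeps the fractional milliseconds.

#include <chrono>
#include <cstdio>
#include <thread>

int main()
{
    const auto start = std::chrono::steady_clock::now();
    std::this_thread::sleep_for(std::chrono::microseconds(2500)); // stand-in for the timed work
    const auto stop = std::chrono::steady_clock::now();

    const auto whole_ms = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
    const std::chrono::duration<double, std::milli> frac_ms = stop - start;

    // Prints something like "2 ms vs 2.531 ms".
    std::printf("%lld ms vs %.3f ms\n",
                static_cast<long long>(whole_ms.count()), frac_ms.count());
}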
